import argparse
import time
from torch.optim import Adam
from chem.dataloader import DataLoaderMasking2
from loader_we import MoleculeDataset2
import torch
from tqdm import tqdm
import numpy as np
from main import CG, CosineDecayScheduler, D_CG
from torch.utils.tensorboard import SummaryWriter
import os


def train_mae(args, model, loader, device):
    """Run the pre-training loop for the masked graph model.

    Args:
        args: parsed CLI namespace; reads ``dataset``, ``lr``, ``epochs``,
            ``save_name`` and ``mm`` (base momentum for the target network).
        model: project model exposing ``trainable_parameters()``,
            ``update_target_network(mm)`` and ``forward(batch, device) -> loss``.
        loader: iterable of training batches (one pass per epoch).
        device: torch device the model forwards on.

    Returns:
        The last epoch's average loss, rounded to 5 decimals
        (0.0 if ``epochs`` is 0 or the loader is empty).

    Side effects: writes TensorBoard scalars under ``./logs/pre_training/``
    and checkpoints under ``result/pre_training/<save_name>/``.
    """
    dataname, lr, num_epochs, save_name, mm_l = (
        args.dataset, args.lr, args.epochs, args.save_name, args.mm)

    writer = SummaryWriter("./logs/pre_training/" + save_name + "/")
    # 100 warmup steps for the LR; momentum decays from mm_l toward 1.
    lr_scheduler = CosineDecayScheduler(lr, 100, num_epochs)
    mm_scheduler = CosineDecayScheduler(1 - mm_l, 0, num_epochs)
    optimizer = Adam(model.trainable_parameters(), lr=lr, weight_decay=1e-4)
    print("开始进行训练......")
    best_loss = float('inf')
    avg_loss = 0.0
    try:
        for epoch in range(1, num_epochs + 1):
            model.train()
            # Cosine-decayed learning rate for this epoch (epoch is 1-based,
            # the scheduler is 0-based).
            lr_now = lr_scheduler.get(epoch - 1)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_now
            # EMA momentum for the target-network update this epoch.
            mm = 1 - mm_scheduler.get(epoch - 1)
            pbar = tqdm(loader)
            losses = 0.0
            for gs in pbar:
                loss = model(gs, device)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                model.update_target_network(mm)
                losses += loss.item()
                pbar.set_description(f"Epoch: {epoch}, Loss: {round(loss.item(), 5)}")
            # Fixed mid-run snapshots (NOTE(review): 50/99 are hard-coded for a
            # 100-epoch schedule; adjust if --epochs changes).
            if epoch in (50, 99):
                torch.save(model.state_dict(), f"result/pre_training/{save_name}/model_{epoch}_pre_training.pth")
            # Guard against an empty loader (original raised ZeroDivisionError).
            avg_loss = losses / len(loader) if len(loader) else 0.0
            if avg_loss < best_loss:
                best_loss = avg_loss
                torch.save(model.state_dict(), f"./result/pre_training/{save_name}/{dataname}_best_model.pth")
            writer.add_scalar("ContrastiveTrain_Loss", round(avg_loss, 5), epoch)
    finally:
        # Original leaked the SummaryWriter; close it so event files flush.
        writer.close()
    return round(avg_loss, 5)


def main():
    """Parse CLI arguments, build the dataset/model, and launch pre-training.

    Side effects: seeds torch/numpy, creates ``./result/pre_training/<name>/``,
    dumps the argument list to ``args.txt`` there, and trains the model.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--save_name', type=str, default='')
    # zinc_standard_agent
    parser.add_argument('--dataset', type=str, default='zinc_standard_agent')
    parser.add_argument('--seed', type=int, default=0, help="Seed for splitting dataset.")
    parser.add_argument('--batch_size', type=int, default=256, help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--mm', type=float, default=0.999)
    parser.add_argument('--num_layer', type=int, default=3, help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.1, help='dropout ratio (default: 0)')
    parser.add_argument('--mask_rate', type=float, default=0.0, help='dropout ratio (default: 0.15)')
    parser.add_argument('--JK', type=str, default="last")
    # Functional-group handling mode: 1 = leave functional groups untouched;
    # 2 = keep long groups plus deduplicated short groups;
    # 3 = keep short groups and do not mask long groups.
    parser.add_argument('--hander_mt', type=int, default=3, help='1,2,3')
    parser.add_argument('--num_workers', type=int, default=12, help='number of workers for dataset loading')
    parser.add_argument("--loss_fn", type=str, default="sce")
    # NOTE(review): type=bool is an argparse pitfall — any non-empty string
    # (e.g. "--random_mask False") parses as True. Kept for CLI compatibility.
    parser.add_argument("--random_mask", type=bool, default=True)
    args = parser.parse_args()

    print(args)
    args_dict = args.__dict__
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Fall back to CPU when CUDA is unavailable (original hard-coded cuda:0
    # and crashed on CPU-only hosts).
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    print("num layer: %d mask rate: %f" % (args.num_layer, args.mask_rate))
    new_folder_name = "Method_" + str(args.hander_mt) + "_" + str(args.dropout_ratio) + "dropout"
    os.makedirs(f"./result/pre_training/{new_folder_name}", exist_ok=True)
    # BUG FIX: save_name was only assigned when the folder did NOT already
    # exist, so reruns saved checkpoints/logs under the empty default name.
    args.save_name = new_folder_name
    with open(f"./result/pre_training/{new_folder_name}/args.txt", 'w', encoding='utf-8') as f:
        for each_arg, value in args_dict.items():
            f.writelines(each_arg + ' : ' + str(value) + '\n')
    dataset_name = args.dataset
    start = time.time()
    dataset = MoleculeDataset2("./dataset/" + dataset_name, dataset=dataset_name, mask_ratio=args.mask_rate,
                               hander_mt=args.hander_mt)
    loader = DataLoaderMasking2(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    print('~~~~using_time~~~~', round(time.time() - start, 6))
    # BUG FIX: original passed mask_rate=args.mask_ratio, but the CLI flag is
    # --mask_rate, so this line raised AttributeError.
    model = CG(args.num_layer, 2, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio,
               mask_rate=args.mask_rate).to(device)
    train_loss = train_mae(args, model, loader, device)
    print(train_loss)


# Script entry point: run the pre-training pipeline when executed directly.
if __name__ == "__main__":
    main()
